package com.chatplus.application.controller.api;

import cn.dev33.satoken.annotation.SaIgnore;
import cn.hutool.core.collection.CollUtil;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.chatplus.application.aiprocessor.platform.image.mj.MjImageProcessor;
import com.chatplus.application.common.annotation.RateLimiter;
import com.chatplus.application.common.enumeration.LimitType;
import com.chatplus.application.common.exception.BadRequestException;
import com.chatplus.application.common.logging.SouthernQuietLogger;
import com.chatplus.application.common.logging.SouthernQuietLoggerFactory;
import com.chatplus.application.common.page.PageParam;
import com.chatplus.application.domain.entity.draw.MjJobEntity;
import com.chatplus.application.domain.request.ImagePublishRequest;
import com.chatplus.application.domain.request.MjActionRequest;
import com.chatplus.application.domain.request.MjCallbackNotifyRequest;
import com.chatplus.application.domain.request.MjImageJobRequest;
import com.chatplus.application.domain.response.MjJobDetailApiResponse;
import com.chatplus.application.domain.response.PlusPageResponse;
import com.chatplus.application.enumeration.ImageTaskTypeEnum;
import com.chatplus.application.enumeration.MjJobStatusEnum;
import com.chatplus.application.service.draw.MjJobService;
import com.chatplus.application.util.BaiduTransUtil;
import com.chatplus.application.web.basecontroller.BaseController;
import com.google.common.collect.Lists;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import org.apache.commons.lang3.StringUtils;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

/**
 * MJ image-generation API.
 * <p>
 * Exposes endpoints for listing the current user's MJ jobs, browsing the public image wall,
 * submitting new image / upscale / variation / action tasks, removing and publishing images,
 * and receiving task-progress callbacks.
 *
 * @author Angus
 */
@Validated
@RestController
@RequestMapping("/api/mj")
@Tag(name = "MJ API", description = "MJ API")
public class MjJobApiController extends BaseController {

    private static final SouthernQuietLogger LOGGER = SouthernQuietLoggerFactory.getLogger(MjJobApiController.class);
    private final MjJobService mjJobService;
    private final MjImageProcessor mjImageProcessor;

    public MjJobApiController(MjJobService mjJobService,
                              MjImageProcessor mjImageProcessor) {
        this.mjJobService = mjJobService;
        this.mjImageProcessor = mjImageProcessor;
    }

    /**
     * Returns the current user's task list (running jobs) or creation history (finished jobs).
     *
     * @param status   {@code 0} for running jobs, {@code 1} for finished jobs; any other value
     *                 yields an empty list
     * @param page     page number (defaults to 1)
     * @param pageSize page size (defaults to 10)
     * @return the requested jobs; never {@code null}
     */
    @GetMapping("/jobs")
    @Operation(summary = "获取任务列表或者创作记录")
    public List<MjJobDetailApiResponse> jobs(@RequestParam(value = "status") Integer status,
                                             @RequestParam(value = "page", required = false, defaultValue = "1") Integer page,
                                             @RequestParam(value = "page_size", required = false, defaultValue = "10") Integer pageSize) {
        if (page == 0 && pageSize == 0) {
            // Both zero is treated as "not specified": fall back to the first page of 20.
            page = 1;
            pageSize = 20;
        }
        if (status == 0) {
            return runningJobs(page, pageSize);
        }
        if (status == 1) {
            // Build the page parameter only AFTER defaulting so the fallback values are honoured.
            PageParam pageParam = new PageParam(page, pageSize);
            return toResponses(mjJobService.getMjJobPage(pageParam, true, getUserId(), null));
        }
        return Collections.emptyList();
    }

    /**
     * Returns the current user's running jobs.
     * <p>
     * When both {@code page} and {@code pageSize} are {@code 0} (the defaults), a single page large
     * enough to hold every running job is returned.
     *
     * @param page     page number, or {@code 0} together with {@code pageSize == 0} for "all"
     * @param pageSize page size, or {@code 0} together with {@code page == 0} for "all"
     * @return the running jobs; never {@code null}
     */
    @GetMapping("/runningJobs")
    @Operation(summary = "获取运行中的任务记录")
    public List<MjJobDetailApiResponse> runningJobs(@RequestParam(value = "page", required = false, defaultValue = "0") Integer page,
                                                    @RequestParam(value = "page_size", required = false, defaultValue = "0") Integer pageSize) {
        PageParam pageParam;
        if (page == 0 && pageSize == 0) {
            // Size one page by the running-job count; the +10 presumably gives headroom for jobs
            // started between the count and the page query.
            pageParam = new PageParam(1, Math.toIntExact(mjJobService.getRunningJobCount(getUserId())) + 10);
        } else {
            pageParam = new PageParam(page, pageSize);
        }
        return toResponses(mjJobService.getMjJobPage(pageParam, false, getUserId(), null));
    }

    /**
     * Returns a page of the current user's finished jobs together with paging metadata.
     *
     * @param page     page number
     * @param pageSize page size
     * @return the page; when no records exist, an empty page echoing the requested page/size
     */
    @GetMapping("/finishJobs")
    @Operation(summary = "获取已完成的记录")
    public PlusPageResponse<MjJobDetailApiResponse> finishJobs(@RequestParam(value = "page") Integer page,
                                                               @RequestParam(value = "page_size") Integer pageSize) {
        PageParam pageParam = new PageParam(page, pageSize);
        Page<MjJobEntity> pageDTO = mjJobService.getMjJobPage(pageParam, true, getUserId(), null);
        if (pageDTO == null || CollUtil.isEmpty(pageDTO.getRecords())) {
            return new PlusPageResponse<>(page, pageSize, 0, 0, Lists.newArrayList());
        }
        List<MjJobDetailApiResponse> responseList = toResponses(pageDTO);
        return new PlusPageResponse<>(pageParam.getCurrent(), pageParam.getSize()
                , pageDTO.getPages(), pageDTO.getTotal(), responseList);
    }

    /**
     * Returns published jobs from all users (the public image wall). No login required.
     * <p>
     * NOTE(review): {@code page}/{@code pageSize} are optional with no default, so {@code null} may
     * be passed to {@link PageParam}; confirm PageParam handles that.
     *
     * @param page     page number, may be {@code null}
     * @param pageSize page size, may be {@code null}
     * @return the published jobs; never {@code null}
     */
    @GetMapping("/imgWall")
    @Operation(summary = "获取任务列表或者创作记录")
    @SaIgnore
    public List<MjJobDetailApiResponse> imgWall(@RequestParam(value = "page", required = false) Integer page,
                                                @RequestParam(value = "page_size", required = false) Integer pageSize) {
        PageParam pageParam = new PageParam(page, pageSize);
        // Finished jobs of any user (userId == null) that are published.
        return toResponses(mjJobService.getMjJobPage(pageParam, true, null, true));
    }

    /**
     * Submits a new image-generation task for the current user.
     * <p>
     * The prompt is passed through {@code BaiduTransUtil.getTransResult(prompt, null, "en")}
     * (presumably a translation to English) before the job is stored and processed.
     *
     * @param request the drawing request
     * @throws BadRequestException if the processor rejects the user's request
     */
    @PostMapping("/image")
    @Operation(summary = "画图任务")
    @RateLimiter(count = 1, time = 5, limitType = LimitType.IP)
    public void image(@RequestBody MjImageJobRequest request) {
        String errMsg = mjImageProcessor.verifyUserRequest(getUserId(), null);
        if (StringUtils.isNotEmpty(errMsg)) {
            throw new BadRequestException(errMsg);
        }
        request.setPrompt(BaiduTransUtil.getTransResult(request.getPrompt(), null, "en"));
        initMjJobEntity(request);
        mjImageProcessor.process(request);
    }

    /**
     * Deletes one of the current user's jobs. Jobs owned by other users are left untouched.
     *
     * @param request carries the id of the job to delete
     */
    @PostMapping("/remove")
    @Operation(summary = "删除图片")
    public void remove(@RequestBody ImagePublishRequest request) {
        LambdaQueryWrapper<MjJobEntity> wrapper = new LambdaQueryWrapper<MjJobEntity>()
                .eq(MjJobEntity::getId, request.getId())
                .eq(MjJobEntity::getUserId, getUserId());
        mjJobService.remove(wrapper);
    }

    /**
     * Upscales one image of an already-generated job.
     *
     * @param request identifies the source task and image index
     * @throws BadRequestException if the user may not submit the task or the source task is unknown
     */
    @PostMapping("/upscale")
    @Operation(summary = "用于选择已生成的图像")
    @RateLimiter(count = 1, time = 5, limitType = LimitType.IP)
    public void upscale(@RequestBody @Valid MjActionRequest request) {
        submitReferencedTask(request, ImageTaskTypeEnum.UPSCALE);
    }

    /**
     * Generates new images in a similar style based on one image of an existing job.
     *
     * @param request identifies the source task and image index
     * @throws BadRequestException if the user may not submit the task or the source task is unknown
     */
    @PostMapping("/variation")
    @Operation(summary = "用于在选中的图像基础上生成四张新的、风格相近的图像")
    @RateLimiter(count = 1, time = 5, limitType = LimitType.IP)
    public void variation(@RequestBody @Valid MjActionRequest request) {
        submitReferencedTask(request, ImageTaskTypeEnum.VARIATION);
    }

    /**
     * Runs a generic action (identified by {@code customId}) on an existing job.
     *
     * @param request identifies the source task, image index and custom action id
     * @throws BadRequestException if the user may not submit the task or the source task is unknown
     */
    @PostMapping("/action")
    @Operation(summary = "用于在选中的图像基础上生成四张新的、风格相近的图像")
    @RateLimiter(count = 1, time = 5, limitType = LimitType.IP)
    public void action(@RequestBody @Valid MjActionRequest request) {
        submitReferencedTask(request, ImageTaskTypeEnum.ACTION);
    }

    /**
     * Shared flow for tasks that operate on an existing job (upscale / variation / action).
     * Validates the user, resolves the referenced job by task id, builds the follow-up request,
     * persists a new job row and hands the request to the processor.
     *
     * @param request  the client action request
     * @param taskType the kind of follow-up task to submit
     * @throws BadRequestException if the processor rejects the user or the referenced task is missing
     */
    private void submitReferencedTask(MjActionRequest request, ImageTaskTypeEnum taskType) {
        String errMsg = mjImageProcessor.verifyUserRequest(getUserId(), taskType);
        if (StringUtils.isNotEmpty(errMsg)) {
            throw new BadRequestException(errMsg);
        }
        MjJobEntity referencedJob = mjJobService.getByTaskId(request.getTaskId());
        if (referencedJob == null) {
            throw new BadRequestException("任务不存在");
        }
        MjImageJobRequest jobRequest = new MjImageJobRequest();
        jobRequest.setTaskType(taskType);
        jobRequest.setPrompt(request.getPrompt());
        jobRequest.setHandleTaskId(request.getTaskId());
        jobRequest.setMessageHash(referencedJob.getHash());
        jobRequest.setIndex(request.getIndex());
        jobRequest.setReferenceId(referencedJob.getId());
        if (taskType == ImageTaskTypeEnum.ACTION) {
            // Only the generic action endpoint forwards the client-supplied customId.
            jobRequest.setCustomId(request.getCustomId());
        }
        initMjJobEntity(jobRequest);
        mjImageProcessor.process(jobRequest);
    }

    /**
     * Persists a new job row for the current user and links it back onto the request.
     * <p>
     * The row starts with progress 0, unpublished, task id {@code "-1"} (presumably a placeholder
     * until the processor assigns the real id), and original URLs collected from
     * {@code imgArr} followed by {@code img}. After saving, the generated id and the user id are
     * copied onto {@code request}.
     *
     * @param request the job request; mutated with {@code mjJobId} and {@code userId}
     */
    private void initMjJobEntity(MjImageJobRequest request) {
        MjJobEntity mjJobEntity = new MjJobEntity();
        mjJobEntity.setUserId(getUserId());
        mjJobEntity.setType(request.getTaskType());
        mjJobEntity.setProgress(0);
        mjJobEntity.setPrompt(request.getPrompt());
        mjJobEntity.setTaskId("-1");
        mjJobEntity.setReferenceId(request.getReferenceId());
        List<String> imgArr = new ArrayList<>();
        if (CollUtil.isNotEmpty(request.getImgArr())) {
            imgArr.addAll(request.getImgArr());
        }
        if (StringUtils.isNotEmpty(request.getImg())) {
            imgArr.add(request.getImg());
        }
        mjJobEntity.setPublish(false);
        mjJobEntity.setOrgUrl(imgArr);
        mjJobService.save(mjJobEntity);
        request.setMjJobId(mjJobEntity.getId());
        request.setUserId(getUserId());
    }

    /**
     * Publishes or unpublishes one of the current user's images.
     *
     * @param request carries the job id and the desired publish flag ({@code action})
     */
    @PostMapping("/publish")
    @Operation(summary = "发布图片")
    public void publish(@RequestBody ImagePublishRequest request) {
        // Restrict the update to the caller's own job, mirroring remove(); otherwise any user
        // could publish or unpublish another user's image just by knowing its id.
        mjJobService.update(new LambdaUpdateWrapper<MjJobEntity>()
                .set(MjJobEntity::getPublish, request.isAction())
                .eq(MjJobEntity::getId, request.getId())
                .eq(MjJobEntity::getUserId, getUserId()));
    }

    /**
     * Callback endpoint for MJ task progress notifications (no login required).
     * <p>
     * The job is resolved first by {@code state}, interpreted as the local job id, then by the
     * remote task id. Callbacks for unknown jobs, or for jobs already at 100% SUCCESS with an
     * image URL, are ignored. Otherwise the service applies the notification and the processor is
     * told the task was updated.
     * <p>
     * NOTE(review): this endpoint is unauthenticated and no caller verification is visible here.
     *
     * @param request the callback payload
     */
    @PostMapping("/callbackNotify")
    @Operation(summary = "回调接口")
    @SaIgnore
    public void callbackNotify(@RequestBody MjCallbackNotifyRequest request) {
        LOGGER.message("收到MJ回调通知").context("request", request).info();
        String taskId = request.getId();
        if (StringUtils.isEmpty(taskId)) {
            LOGGER.message("回调通知参数错误").error();
            return;
        }
        String state = request.getState();
        MjJobEntity mjJobEntity = null;
        if (StringUtils.isNotEmpty(state)) {
            try {
                mjJobEntity = mjJobService.getById(Long.valueOf(state));
            } catch (Exception e) {
                // Best effort: state may not be a numeric job id. Fall back to the task-id lookup below.
                LOGGER.message("Failed to resolve MJ job from callback state, falling back to taskId")
                        .context("state", state)
                        .context("error", e.getMessage())
                        .warn();
            }
        }
        if (mjJobEntity == null) {
            mjJobEntity = mjJobService.getByTaskId(taskId);
        }
        if (mjJobEntity == null) {
            LOGGER.message("绘画任务不存在或已完成处理").context("state", state).context("taskId", taskId).warn();
            return;
        }
        // Null-safe progress check: avoids an NPE if progress is an unset Integer.
        if (Objects.equals(mjJobEntity.getProgress(), 100)
                && mjJobEntity.getStatus() == MjJobStatusEnum.SUCCESS
                && StringUtils.isNotEmpty(mjJobEntity.getImgUrl())) {
            LOGGER.message("绘画任务已完成处理").context("state", state).context("taskId", taskId).warn();
            return;
        }
        mjJobService.handleNotify(mjJobEntity, request);
        mjImageProcessor.notifyUpdateTask(mjJobEntity.getId());
    }

    /**
     * Converts a page of job entities into API responses.
     *
     * @param pageDTO the page returned by the service; may be {@code null}
     * @return the mapped records, or an empty list when the page is {@code null} or has no records
     */
    private static List<MjJobDetailApiResponse> toResponses(Page<MjJobEntity> pageDTO) {
        if (pageDTO == null || CollUtil.isEmpty(pageDTO.getRecords())) {
            return Collections.emptyList();
        }
        return pageDTO.getRecords().stream().map(MjJobDetailApiResponse::build).toList();
    }
}
